{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Import all the necessary packages for training the graph convolution network, GCN. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "from keras.layers import Input, Dropout\n",
    "from keras.models import Model\n",
    "from keras.optimizers import Adam\n",
    "from keras.regularizers import l2\n",
    "\n",
    "from kegra.layers.graph import GraphConvolution\n",
    "import utils\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET = 'cora'\n",
    "FILTER = 'localpool'  # 'chebyshev'\n",
    "MAX_DEGREE = 2  # maximum polynomial degree\n",
    "SYM_NORM = True  # symmetric (True) vs. left-only (False) normalization\n",
    "NB_EPOCH = 200\n",
    "PATIENCE = 10  # early stopping patience"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load data using function from utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading cora dataset...\n",
      "Dataset has 2708 nodes, 5429 edges, 1433 features.\n"
     ]
    }
   ],
   "source": [
    "X, A, y = utils.load_data(dataset=DATASET)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Split dataset into train, validation, and test batches. The focus here is to split only the labels, not X or A itself. We can sample the training data in X and A by train_mask.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train, y_val, y_test, idx_train, idx_val, idx_test, train_mask = utils.get_splits(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Lets verify the dataset we loaded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X dim=(2708, 1433)\n",
      "A dim=(2708, 2708)\n",
      "y dim=(2708, 7)\n",
      "y_train dim=(2708, 7)\n",
      "y_val dim=(2708, 7)\n",
      "y_test dim=(2708, 7)\n",
      "80\n",
      "[ True  True  True ... False False False]\n"
     ]
    }
   ],
   "source": [
    "print(\"X dim={}\".format(X.shape))\n",
    "print(\"A dim={}\".format(A.shape))\n",
    "print(\"y dim={}\".format(y.shape))\n",
    "\n",
    "print(\"y_train dim={}\".format(y_train.shape))\n",
    "print(\"y_val dim={}\".format(y_val.shape))\n",
    "print(\"y_test dim={}\".format(y_test.shape))\n",
    "\n",
    "tmp_sum = 0\n",
    "for i in range(train_mask.shape[0]):\n",
    "    if train_mask[i] == True:\n",
    "        tmp_sum+=1\n",
    "\n",
    "print(tmp_sum)\n",
    "print(train_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Normalize X so that each one-hot embedded paper has $||x_i|| = 1$. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize X\n",
    "from numpy import linalg as LA\n",
    "for i in range(X.shape[0]):\n",
    "    X[i] /= LA.norm(X[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Take a peek at the 100-th one-hot embedded paper after performing normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "62:0.19999997317790985\n",
      "99:0.19999997317790985\n",
      "132:0.19999997317790985\n",
      "142:0.19999997317790985\n",
      "292:0.19999997317790985\n",
      "402:0.19999997317790985\n",
      "462:0.19999997317790985\n",
      "495:0.19999997317790985\n",
      "507:0.19999997317790985\n",
      "575:0.19999997317790985\n",
      "648:0.19999997317790985\n",
      "675:0.19999997317790985\n",
      "724:0.19999997317790985\n",
      "733:0.19999997317790985\n",
      "778:0.19999997317790985\n",
      "779:0.19999997317790985\n",
      "821:0.19999997317790985\n",
      "1071:0.19999997317790985\n",
      "1097:0.19999997317790985\n",
      "1151:0.19999997317790985\n",
      "1230:0.19999997317790985\n",
      "1331:0.19999997317790985\n",
      "1334:0.19999997317790985\n",
      "1348:0.19999997317790985\n",
      "1422:0.19999997317790985\n",
      "tmp_sum=[[0.9999994]]==1\n"
     ]
    }
   ],
   "source": [
    "utils.look_sparse_matrix(X, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using local pooling filters...\n"
     ]
    }
   ],
   "source": [
    "if FILTER == 'localpool':\n",
    "    \"\"\" Local pooling filters (see 'renormalization trick' in Kipf & Welling, arXiv 2016) \"\"\"\n",
    "    print('Using local pooling filters...')\n",
    "    A_hat = utils.preprocess_adj(A, SYM_NORM)\n",
    "    support = 1\n",
    "    graph = [X, A_hat]\n",
    "elif FILTER == 'chebyshev':\n",
    "    \"\"\" Chebyshev polynomial basis filters (Defferard et al., NIPS 2016)  \"\"\"\n",
    "    print('Using Chebyshev polynomial basis filters...')\n",
    "    L = normalized_laplacian(A, SYM_NORM)\n",
    "    L_scaled = rescale_laplacian(L)\n",
    "    T_k = chebyshev_polynomial(L_scaled, MAX_DEGREE)\n",
    "    support = MAX_DEGREE + 1\n",
    "    graph = [X]+T_k\n",
    "    \n",
    "else:\n",
    "    raise Exception('Invalid filter type.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:423: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    }
   ],
   "source": [
    "#Define input tensor\n",
    "A_hat_in = Input(shape=A_hat.shape, batch_shape=(None, None), sparse=True)\n",
    "X_in = Input(shape=(X.shape[1],))\n",
    "\n",
    "# Define model architecture\n",
    "# NOTE: We pass arguments for graph convolutional layers as a list of tensors.\n",
    "# This is somewhat hacky, more elegant options would require rewriting the Layer base class.\n",
    "H_1 = Dropout(0.5)(X_in)\n",
    "H_2 = GraphConvolution(16, support, activation='relu', kernel_regularizer=l2(5e-4))([H_1]+[A_hat_in])\n",
    "H_2 = Dropout(0.5)(H_2)\n",
    "Y_out = GraphConvolution(y.shape[1], support, activation='softmax')([H_2]+[A_hat_in]) #H_{l+1} = f_act(H_l, A)\n",
    "\n",
    "# Compile model\n",
    "model = Model(inputs=[X_in]+[A_hat_in], outputs=Y_out)\n",
    "model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.01))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reminders\n",
    "#### line 4: $\\hat{A} = \\tilde{D}^{-\\frac{1}{2}} \\tilde{A} \\tilde{D}^{\\frac{1}{2}}$, where $\\tilde{A} = A + I_N$, and $\\tilde{D}_{ii} = \\sum_j \\tilde{A}_{ij}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Lets take a look at the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_2 (InputLayer)            (None, 1433)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 1433)         0           input_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "input_1 (InputLayer)            (None, None)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "graph_convolution_1 (GraphConvo (None, 16)           22944       dropout_1[0][0]                  \n",
      "                                                                 input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_2 (Dropout)             (None, 16)           0           graph_convolution_1[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "graph_convolution_2 (GraphConvo (None, 7)            119         dropout_2[0][0]                  \n",
      "                                                                 input_1[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 23,063\n",
      "Trainable params: 23,063\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper variables for main training loop\n",
    "wait = 0\n",
    "preds = None\n",
    "best_val_loss = 99999"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 train_loss= 0.0846 train_acc= 1.0000 val_loss= 0.7659 val_acc= 0.7905 time= 0.0702\n",
      "Epoch: 0002 train_loss= 0.0837 train_acc= 1.0000 val_loss= 0.7634 val_acc= 0.7905 time= 0.0773\n",
      "Epoch: 0003 train_loss= 0.0828 train_acc= 1.0000 val_loss= 0.7599 val_acc= 0.7881 time= 0.0814\n",
      "Epoch: 0004 train_loss= 0.0820 train_acc= 1.0000 val_loss= 0.7551 val_acc= 0.7881 time= 0.0775\n",
      "Epoch: 0005 train_loss= 0.0810 train_acc= 1.0000 val_loss= 0.7509 val_acc= 0.7881 time= 0.0718\n",
      "Epoch: 0006 train_loss= 0.0800 train_acc= 1.0000 val_loss= 0.7485 val_acc= 0.7881 time= 0.0578\n",
      "Epoch: 0007 train_loss= 0.0791 train_acc= 1.0000 val_loss= 0.7464 val_acc= 0.7857 time= 0.0711\n",
      "Epoch: 0008 train_loss= 0.0781 train_acc= 1.0000 val_loss= 0.7452 val_acc= 0.7881 time= 0.0700\n",
      "Epoch: 0009 train_loss= 0.0774 train_acc= 1.0000 val_loss= 0.7444 val_acc= 0.7881 time= 0.0832\n",
      "Epoch: 0010 train_loss= 0.0766 train_acc= 1.0000 val_loss= 0.7444 val_acc= 0.7881 time= 0.0731\n",
      "Epoch: 0011 train_loss= 0.0759 train_acc= 1.0000 val_loss= 0.7440 val_acc= 0.7857 time= 0.0743\n",
      "Epoch: 0012 train_loss= 0.0751 train_acc= 1.0000 val_loss= 0.7451 val_acc= 0.7857 time= 0.0619\n",
      "Epoch: 0013 train_loss= 0.0745 train_acc= 1.0000 val_loss= 0.7461 val_acc= 0.7833 time= 0.0633\n",
      "Epoch: 0014 train_loss= 0.0743 train_acc= 1.0000 val_loss= 0.7501 val_acc= 0.7810 time= 0.0724\n",
      "Epoch: 0015 train_loss= 0.0740 train_acc= 1.0000 val_loss= 0.7533 val_acc= 0.7786 time= 0.0761\n",
      "Epoch: 0016 train_loss= 0.0734 train_acc= 1.0000 val_loss= 0.7558 val_acc= 0.7762 time= 0.0774\n",
      "Epoch: 0017 train_loss= 0.0730 train_acc= 1.0000 val_loss= 0.7579 val_acc= 0.7738 time= 0.0870\n",
      "Epoch: 0018 train_loss= 0.0723 train_acc= 1.0000 val_loss= 0.7574 val_acc= 0.7738 time= 0.0911\n",
      "Epoch: 0019 train_loss= 0.0717 train_acc= 1.0000 val_loss= 0.7548 val_acc= 0.7810 time= 0.0802\n",
      "Epoch: 0020 train_loss= 0.0712 train_acc= 1.0000 val_loss= 0.7522 val_acc= 0.7810 time= 0.0719\n",
      "Epoch: 0021 train_loss= 0.0709 train_acc= 1.0000 val_loss= 0.7512 val_acc= 0.7810 time= 0.0854\n",
      "Epoch: 0022 train_loss= 0.0709 train_acc= 1.0000 val_loss= 0.7503 val_acc= 0.7786 time= 0.0735\n",
      "Epoch 22: early stopping\n",
      "Test set results: loss= 0.7900 accuracy= 0.7758\n"
     ]
    }
   ],
   "source": [
    "# Fit\n",
    "for epoch in range(1, NB_EPOCH+1):\n",
    "\n",
    "    # Log wall-clock time\n",
    "    t = time.time()\n",
    "\n",
    "    # Single training iteration (we mask nodes without labels for loss calculation)\n",
    "    model.fit(graph, y_train, sample_weight=train_mask,\n",
    "              batch_size=A.shape[0], epochs=1, shuffle=False, verbose=0)\n",
    "\n",
    "    # Predict on full dataset\n",
    "    preds = model.predict(graph, batch_size=A.shape[0])\n",
    "    \n",
    "    # Train / validation scores\n",
    "    train_val_loss, train_val_acc = utils.evaluate_preds(preds, [y_train, y_val],\n",
    "                                                   [idx_train, idx_val])\n",
    "    print(\"Epoch: {:04d}\".format(epoch),\n",
    "          \"train_loss= {:.4f}\".format(train_val_loss[0]),\n",
    "          \"train_acc= {:.4f}\".format(train_val_acc[0]),\n",
    "          \"val_loss= {:.4f}\".format(train_val_loss[1]),\n",
    "          \"val_acc= {:.4f}\".format(train_val_acc[1]),\n",
    "          \"time= {:.4f}\".format(time.time() - t))\n",
    "\n",
    "    # Early stopping\n",
    "    if train_val_loss[1] < best_val_loss:\n",
    "        best_val_loss = train_val_loss[1]\n",
    "        wait = 0\n",
    "    else:\n",
    "        if wait >= PATIENCE:\n",
    "            print('Epoch {}: early stopping'.format(epoch))\n",
    "            break\n",
    "        wait += 1\n",
    "\n",
    "# Testing\n",
    "test_loss, test_acc = utils.evaluate_preds(preds, [y_test], [idx_test])\n",
    "print(\"Test set results:\",\n",
    "      \"loss= {:.4f}\".format(test_loss[0]),\n",
    "      \"accuracy= {:.4f}\".format(test_acc[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
